{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/jiunyiyang/opt/anaconda3/envs/python3.7/lib/python3.7/site-packages/tqdm/std.py:697: FutureWarning: The Panel class is removed from pandas. Accessing it from the top-level namespace will also be removed in the next version\n",
      "  from pandas import Panel\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from IPython.display import display\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score\n",
    "from sklearn.metrics import confusion_matrix, classification_report, precision_score\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from tqdm.notebook import tqdm\n",
    "tqdm.pandas()\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (1)  Import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sepal length (cm)</th>\n",
       "      <th>sepal width (cm)</th>\n",
       "      <th>petal length (cm)</th>\n",
       "      <th>petal width (cm)</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>5.1</td>\n",
       "      <td>3.5</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>4.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>4.7</td>\n",
       "      <td>3.2</td>\n",
       "      <td>1.3</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4.6</td>\n",
       "      <td>3.1</td>\n",
       "      <td>1.5</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5.0</td>\n",
       "      <td>3.6</td>\n",
       "      <td>1.4</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>145</th>\n",
       "      <td>6.7</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.3</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>6.3</td>\n",
       "      <td>2.5</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.9</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>147</th>\n",
       "      <td>6.5</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.2</td>\n",
       "      <td>2.0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>148</th>\n",
       "      <td>6.2</td>\n",
       "      <td>3.4</td>\n",
       "      <td>5.4</td>\n",
       "      <td>2.3</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>149</th>\n",
       "      <td>5.9</td>\n",
       "      <td>3.0</td>\n",
       "      <td>5.1</td>\n",
       "      <td>1.8</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>150 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  \\\n",
       "0                  5.1               3.5                1.4               0.2   \n",
       "1                  4.9               3.0                1.4               0.2   \n",
       "2                  4.7               3.2                1.3               0.2   \n",
       "3                  4.6               3.1                1.5               0.2   \n",
       "4                  5.0               3.6                1.4               0.2   \n",
       "..                 ...               ...                ...               ...   \n",
       "145                6.7               3.0                5.2               2.3   \n",
       "146                6.3               2.5                5.0               1.9   \n",
       "147                6.5               3.0                5.2               2.0   \n",
       "148                6.2               3.4                5.4               2.3   \n",
       "149                5.9               3.0                5.1               1.8   \n",
       "\n",
       "     target  \n",
       "0         0  \n",
       "1         0  \n",
       "2         0  \n",
       "3         0  \n",
       "4         0  \n",
       "..      ...  \n",
       "145       2  \n",
       "146       2  \n",
       "147       2  \n",
       "148       2  \n",
       "149       2  \n",
       "\n",
       "[150 rows x 5 columns]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "df = pd.DataFrame(load_iris()['data'], columns=load_iris()['feature_names'])\n",
    "df = pd.concat([df, pd.DataFrame(load_iris()['target'], columns=['target'])], axis=1)\n",
    "target_names = load_iris()['target_names']\n",
    "display(df)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (2)  Train/Test Sets Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shape of X_train: (105, 3)\n",
      "shape of X_test: (45, 3)\n",
      "shape of y_train: (105, 1)\n",
      "shape of y_test: (45, 1)\n"
     ]
    }
   ],
   "source": [
    "X, y = df.iloc[:, :3], df.iloc[:, 4:5]\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)\n",
    "\n",
    "print(f'shape of X_train: {X_train.shape}')\n",
    "print(f'shape of X_test: {X_test.shape}')\n",
    "print(f'shape of y_train: {y_train.shape}')\n",
    "print(f'shape of y_test: {y_test.shape}')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (3)  Scaling\n",
    "- standardlization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.58567306 1.25646135 1.30924548]\n",
      "[-0.28637946 -0.52635543  0.63232742]\n"
     ]
    }
   ],
   "source": [
    "scaler = StandardScaler()\n",
    "scaler.fit(X_train)\n",
    "X_train_scaled = scaler.transform(X_train)\n",
    "X_test_scaled = scaler.transform(X_test)\n",
    "\n",
    "print(X_train_scaled[0])\n",
    "print(X_test_scaled[0])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (4)  Grid Search for Parameters\n",
    "- using cross validation in grid search on training set\n",
    "- using AUC score for evalutation\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [],
   "source": [
    "l1_params = {\n",
    "    'solver': ['liblinear'], \n",
    "    'C': np.linspace(0.05, 1, 19)\n",
    "}\n",
    "\n",
    "l2_params = {\n",
    "    'solver': ['liblinear', 'Ibfgs', 'newton-cg', 'sag'], \n",
    "    'C': np.linspace(0.05, 1, 19)\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best score of L1 penalty classifier:  0.8857142857142856\n",
      "Best parameters set: {'C': 0.8416666666666667, 'solver': 'liblinear'}\n",
      "Best score of L2 penalty classifier:  0.9047619047619048\n",
      "Best parameters set: {'C': 0.788888888888889, 'solver': 'newton-cg'}\n"
     ]
    }
   ],
   "source": [
    "kfold_5 = StratifiedKFold(n_splits=5)\n",
    "\n",
    "lr_l1 = LogisticRegression(penalty='l1', solver='liblinear', max_iter=1000)\n",
    "clf_l1 = GridSearchCV(estimator=lr_l1, param_grid=l1_params, scoring='accuracy', cv=kfold_5)\n",
    "clf_l1.fit(X_train_scaled, y_train)\n",
    "print('Best score of L1 penalty classifier: ', clf_l1.best_score_)\n",
    "print(f'Best parameters set: {clf_l1.best_params_}')\n",
    "best_clf_l1 = clf_l1.best_estimator_\n",
    "\n",
    "lr_l2 = LogisticRegression(penalty='l2', max_iter=1000)\n",
    "clf_l2 = GridSearchCV(estimator=lr_l2, param_grid=l2_params, scoring='accuracy', cv=kfold_5)\n",
    "clf_l2.fit(X_train_scaled, y_train)\n",
    "print('Best score of L2 penalty classifier: ', clf_l2.best_score_)\n",
    "print(f'Best parameters set: {clf_l2.best_params_}')\n",
    "best_clf_l2 = clf_l2.best_estimator_\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Best parameters: {\n",
    "    'penalty': 'l2', 'C': 0.788888888888889, 'solver': 'newton-cg'\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (5)(6)  Apply on testing set\n",
    "- confusion matrix\n",
    "- classification report\n",
    "- precision score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[16,  0,  0],\n",
       "       [ 0, 14,  1],\n",
       "       [ 0,  0, 14]])"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "best_clf_l2.fit(X_train_scaled, y_train)\n",
    "Y_pred = best_clf_l2.predict(X_test_scaled)\n",
    "confusion_matrix(y_true=y_test, y_pred=Y_pred)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00        16\n",
      "           1       1.00      0.93      0.97        15\n",
      "           2       0.93      1.00      0.97        14\n",
      "\n",
      "    accuracy                           0.98        45\n",
      "   macro avg       0.98      0.98      0.98        45\n",
      "weighted avg       0.98      0.98      0.98        45\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_true=y_test, y_pred=Y_pred))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9777777777777777"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "precision_score(y_true=y_test, y_pred=Y_pred, average='micro')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
